/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/**
 * @addtogroup MindSpore
 * @{
 * 
 * @brief Provides APIs related to MindSpore Lite model inference.
 * 
 * @Syscap SystemCapability.Ai.MindSpore
 * @since 9
 */

/**
 * @file model.h
 * 
 * @brief Provides model-related APIs for model creation and inference.
 * 
 * File to include: <mindspore/model.h>
 * @library libmindspore_lite_ndk.so
 * @since 9
 */

#ifndef MINDSPORE_INCLUDE_C_API_MODEL_C_H
#define MINDSPORE_INCLUDE_C_API_MODEL_C_H

#include "mindspore/tensor.h"
#include "mindspore/context.h"
#include "mindspore/status.h"

#ifdef __cplusplus
extern "C"
{
#endif
/**
 * @brief Defines the pointer to a model object.
 *
 * Opaque handle (a **void** pointer). Obtain it with {@link OH_AI_ModelCreate} and release it
 * with {@link OH_AI_ModelDestroy}.
 *
 * @since 9
 */
typedef void *OH_AI_ModelHandle;

/**
 * @brief Defines the pointer to a training configuration object.
 *
 * Opaque handle (a **void** pointer). Obtain it with {@link OH_AI_TrainCfgCreate} and release it
 * with {@link OH_AI_TrainCfgDestroy}.
 *
 * @since 11
 */
typedef void *OH_AI_TrainCfgHandle;

/**
 * @brief Defines the tensor array structure, which is used to store the tensor array pointer
 * and tensor array length.
 *
 * **handle_list** points to **handle_num** consecutive {@link OH_AI_TensorHandle} elements.
 *
 * @since 9
 */
typedef struct OH_AI_TensorHandleArray
{
  /** Tensor array length: number of elements in **handle_list** */
  size_t handle_num;
  /** Tensor array pointer */
  OH_AI_TensorHandle *handle_list;
} OH_AI_TensorHandleArray;

/**
 * @brief Defines the maximum tensor dimension.
 *
 * Used as the fixed capacity of the **shape** array in {@link OH_AI_ShapeInfo}.
 *
 * @since 9
 */
#define OH_AI_MAX_SHAPE_NUM 32

/**
 * @brief Defines dimension information. The maximum dimension is set by {@link OH_AI_MAX_SHAPE_NUM}.
 *
 * @since 9
 */
typedef struct OH_AI_ShapeInfo
{
  /** Dimension array length; must not exceed {@link OH_AI_MAX_SHAPE_NUM} (the capacity of **shape**) */
  size_t shape_num;
  /** Dimension array; only the first **shape_num** entries are meaningful */
  int64_t shape[OH_AI_MAX_SHAPE_NUM];
} OH_AI_ShapeInfo;

/**
 * @brief Defines the operator information passed in a callback.
 *
 * An instance of this structure is delivered as the third argument of {@link OH_AI_KernelCallBack}.
 *
 * @since 9
 */
typedef struct OH_AI_CallBackParam
{
  /** Operator name */
  char *node_name;
  /** Operator type */
  char *node_type;
} OH_AI_CallBackParam;

/**
 * @brief Defines the pointer to a callback.
 *
 * This pointer is used to set the two callback functions in {@link OH_AI_ModelPredict}
 * and {@link OH_AI_RunStep}.
 * Each callback function must contain three parameters, where **inputs** and **outputs** indicate
 * the input and output tensors of the operator, and **kernel_Info** indicates information about
 * the current operator.
 * You can use the callback functions to monitor the operator execution status, for example,
 * operator execution time and the operator correctness.
 *
 * @since 9
 */
typedef bool (*OH_AI_KernelCallBack)(const OH_AI_TensorHandleArray inputs, const OH_AI_TensorHandleArray outputs,
                                      const OH_AI_CallBackParam kernel_Info);

/**
 * @brief Creates a model object.
 *
 * The returned handle must be released with {@link OH_AI_ModelDestroy} when it is no longer needed.
 *
 * @return Pointer to the model object.
 * @since 9
 */
OH_AI_API OH_AI_ModelHandle OH_AI_ModelCreate(void);

/**
 * @brief Destroys a model object.
 *
 * @param model Pointer to the model object. A pointer to the handle (not the handle itself) is
 * passed, presumably so the implementation can reset the handle after release — confirm with
 * the implementation.
 * @since 9
 */
OH_AI_API void OH_AI_ModelDestroy(OH_AI_ModelHandle *model);

/**
 * @brief Loads and builds a MindSpore model from the memory buffer.
 *
 * Note that the same {@link OH_AI_ContextHandle} object can be passed to {@link OH_AI_ModelBuild} or
 * {@link OH_AI_ModelBuildFromFile} only once.
 * If you call this function multiple times, make sure that you create multiple
 * {@link OH_AI_ContextHandle} objects accordingly.
 *
 * @param model Pointer to the model object, created by {@link OH_AI_ModelCreate}.
 * @param model_data Address of the loaded model data in the memory.
 * @param data_size Length of the model data, in bytes.
 * @param model_type Type of the model file. For details, see {@link OH_AI_ModelType}.
 * @param model_context Context for model running. For details, see {@link OH_AI_ContextHandle}.
 * @return Status code enumerated by {@link OH_AI_Status}. The value **OH_AI_Status::OH_AI_STATUS_SUCCESS**
 * indicates that the operation is successful.
 * @since 9
 */
OH_AI_API OH_AI_Status OH_AI_ModelBuild(OH_AI_ModelHandle model, const void *model_data, size_t data_size,
                                        OH_AI_ModelType model_type, const OH_AI_ContextHandle model_context);

/**
 * @brief Loads and builds a MindSpore model from a model file.
 *
 * Note that the same {@link OH_AI_ContextHandle} object can be passed to {@link OH_AI_ModelBuild} or
 * {@link OH_AI_ModelBuildFromFile} only once.
 * If you call this function multiple times, make sure that you create multiple
 * {@link OH_AI_ContextHandle} objects accordingly.
 *
 * @param model Pointer to the model object, created by {@link OH_AI_ModelCreate}.
 * @param model_path Path of the model file.
 * @param model_type Type of the model file. For details, see {@link OH_AI_ModelType}.
 * @param model_context Context for model running. For details, see {@link OH_AI_ContextHandle}.
 * @return Status code enumerated by {@link OH_AI_Status}. The value **OH_AI_Status::OH_AI_STATUS_SUCCESS**
 * indicates that the operation is successful.
 * @since 9
 */
OH_AI_API OH_AI_Status OH_AI_ModelBuildFromFile(OH_AI_ModelHandle model, const char *model_path,
                                                OH_AI_ModelType model_type, const OH_AI_ContextHandle model_context);

/**
 * @brief Adjusts the input tensor shapes of a built model.
 *
 * @param model Pointer to the model object.
 * @param inputs Tensor array structure corresponding to the model input.
 * @param shape_infos Input shape array, which consists of tensor shapes arranged in the model input sequence.
 * The model adjusts the tensor shapes in sequence.
 * @param shape_info_num Length of the input shape array; presumably it should equal the number of
 * model inputs — confirm with the implementation.
 * @return Status code enumerated by {@link OH_AI_Status}. The value **OH_AI_Status::OH_AI_STATUS_SUCCESS**
 * indicates that the operation is successful.
 * @since 9
 */
OH_AI_API OH_AI_Status OH_AI_ModelResize(OH_AI_ModelHandle model, const OH_AI_TensorHandleArray inputs,
                                          OH_AI_ShapeInfo *shape_infos, size_t shape_info_num);

/**
 * @brief Performs model inference.
 *
 * @param model Pointer to the model object.
 * @param inputs Tensor array structure corresponding to the model input.
 * @param outputs Pointer to the tensor array structure corresponding to the model output.
 * @param before Callback function executed before model inference; see {@link OH_AI_KernelCallBack}.
 * Presumably **NULL** disables the callback — confirm with the implementation.
 * @param after Callback function executed after model inference; see {@link OH_AI_KernelCallBack}.
 * @return Status code enumerated by {@link OH_AI_Status}. The value **OH_AI_Status::OH_AI_STATUS_SUCCESS**
 * indicates that the operation is successful.
 * @since 9
 */
OH_AI_API OH_AI_Status OH_AI_ModelPredict(OH_AI_ModelHandle model, const OH_AI_TensorHandleArray inputs,
                                          OH_AI_TensorHandleArray *outputs, const OH_AI_KernelCallBack before,
                                          const OH_AI_KernelCallBack after);

/**
 * @brief Obtains the input tensor array structure of a model.
 *
 * @param model Pointer to the model object.
 * @return Tensor array structure ({@link OH_AI_TensorHandleArray}) corresponding to the model input.
 * @since 9
 */
OH_AI_API OH_AI_TensorHandleArray OH_AI_ModelGetInputs(const OH_AI_ModelHandle model);

/**
 * @brief Obtains the output tensor array structure of a model.
 *
 * @param model Pointer to the model object.
 * @return Tensor array structure ({@link OH_AI_TensorHandleArray}) corresponding to the model output.
 * @since 9
 */
OH_AI_API OH_AI_TensorHandleArray OH_AI_ModelGetOutputs(const OH_AI_ModelHandle model);

/**
 * @brief Obtains the input tensor of a model by tensor name.
 *
 * @param model Pointer to the model object.
 * @param tensor_name Tensor name.
 * @return Pointer to the input tensor indicated by **tensor_name**. If the tensor does not exist in the input,
 * **null** will be returned.
 * @since 9
 */
OH_AI_API OH_AI_TensorHandle OH_AI_ModelGetInputByTensorName(const OH_AI_ModelHandle model, const char *tensor_name);

/**
 * @brief Obtains the output tensor of a model by tensor name.
 *
 * @param model Pointer to the model object.
 * @param tensor_name Tensor name.
 * @return Pointer to the output tensor indicated by **tensor_name**. If the tensor does not exist in the output,
 * **null** will be returned.
 * @since 9
 */
OH_AI_API OH_AI_TensorHandle OH_AI_ModelGetOutputByTensorName(const OH_AI_ModelHandle model, const char *tensor_name);

/**
 * @brief Creates the pointer to the training configuration object. This API is used only for on-device training.
 *
 * The returned handle must be released with {@link OH_AI_TrainCfgDestroy} when it is no longer needed.
 *
 * @return Pointer to the training configuration object.
 * @since 11
 */
OH_AI_API OH_AI_TrainCfgHandle OH_AI_TrainCfgCreate(void);

/**
 * @brief Destroys the pointer to the training configuration object. This API is used only for on-device training.
 *
 * @param train_cfg Pointer to the training configuration object. A pointer to the handle (not the
 * handle itself) is passed, presumably so the implementation can reset the handle after release —
 * confirm with the implementation.
 * @since 11
 */
OH_AI_API void OH_AI_TrainCfgDestroy(OH_AI_TrainCfgHandle *train_cfg);

/**
 * @brief Obtains the list of loss functions, which are used only for on-device training.
 *
 * @param train_cfg Pointer to the training configuration object.
 * @param num Output parameter; receives the number of loss functions.
 * @return List of loss function names. NOTE(review): ownership of the returned array is not stated
 * here — confirm with the implementation whether the caller must free it.
 * @since 11
 */
OH_AI_API char **OH_AI_TrainCfgGetLossName(OH_AI_TrainCfgHandle train_cfg, size_t *num);

/**
 * @brief Sets the list of loss functions, which are used only for on-device training.
 *
 * @param train_cfg Pointer to the training configuration object.
 * @param loss_name List of loss function names.
 * @param num Number of entries in **loss_name**.
 * @since 11
 */
OH_AI_API void OH_AI_TrainCfgSetLossName(OH_AI_TrainCfgHandle train_cfg, const char **loss_name, size_t num);

/**
 * @brief Obtains the optimization level of the training configuration object. This API is used only for
 * on-device training.
 *
 * @param train_cfg Pointer to the training configuration object.
 * @return Optimization level; see {@link OH_AI_OptimizationLevel}.
 * @since 11
 */
OH_AI_API OH_AI_OptimizationLevel OH_AI_TrainCfgGetOptimizationLevel(OH_AI_TrainCfgHandle train_cfg);

/**
 * @brief Sets the optimization level of the training configuration object. This API is used only for
 * on-device training.
 *
 * @param train_cfg Pointer to the training configuration object.
 * @param level Optimization level; see {@link OH_AI_OptimizationLevel}.
 * @since 11
 */
OH_AI_API void OH_AI_TrainCfgSetOptimizationLevel(OH_AI_TrainCfgHandle train_cfg, OH_AI_OptimizationLevel level);

/**
 * @brief Loads a training model from the memory buffer and compiles the model to a state ready for
 * running on the device. This API is used only for on-device training.
 *
 * @param model Pointer to the model object, created by {@link OH_AI_ModelCreate}.
 * @param model_data Pointer to the buffer that stores the model file to read.
 * @param data_size Buffer size, in bytes.
 * @param model_type Type of the model file. For details, see {@link OH_AI_ModelType}.
 * @param model_context Context for model running. For details, see {@link OH_AI_ContextHandle}.
 * @param train_cfg Pointer to the training configuration object.
 * @return Status code enumerated by {@link OH_AI_Status}. The value **OH_AI_Status::OH_AI_STATUS_SUCCESS**
 * indicates that the operation is successful.
 * @since 11
 */
OH_AI_API OH_AI_Status OH_AI_TrainModelBuild(OH_AI_ModelHandle model, const void *model_data, size_t data_size,
                                             OH_AI_ModelType model_type, const OH_AI_ContextHandle model_context,
                                             const OH_AI_TrainCfgHandle train_cfg);

/**
 * @brief Loads the training model from the specified path and compiles the model to a state ready for
 * running on the device. This API is used only for on-device training.
 *
 * @param model Pointer to the model object, created by {@link OH_AI_ModelCreate}.
 * @param model_path Path of the model file.
 * @param model_type Type of the model file. For details, see {@link OH_AI_ModelType}.
 * @param model_context Context for model running. For details, see {@link OH_AI_ContextHandle}.
 * @param train_cfg Pointer to the training configuration object.
 * @return Status code enumerated by {@link OH_AI_Status}. The value **OH_AI_Status::OH_AI_STATUS_SUCCESS**
 * indicates that the operation is successful.
 * @since 11
 */
OH_AI_API OH_AI_Status OH_AI_TrainModelBuildFromFile(OH_AI_ModelHandle model, const char *model_path,
                                                     OH_AI_ModelType model_type,
                                                     const OH_AI_ContextHandle model_context,
                                                     const OH_AI_TrainCfgHandle train_cfg);

/**
 * @brief Performs a single training step of the model. This API is used only for on-device training.
 *
 * @param model Pointer to the model object.
 * @param before Callback function executed before the training step; see {@link OH_AI_KernelCallBack}.
 * @param after Callback function executed after the training step; see {@link OH_AI_KernelCallBack}.
 * @return Status code enumerated by {@link OH_AI_Status}. The value **OH_AI_Status::OH_AI_STATUS_SUCCESS**
 * indicates that the operation is successful.
 * @since 11
 */
OH_AI_API OH_AI_Status OH_AI_RunStep(OH_AI_ModelHandle model, const OH_AI_KernelCallBack before,
                                     const OH_AI_KernelCallBack after);

/**
 * @brief Sets the learning rate for model training. This API is used only for on-device training.
 *
 * @param model Pointer to the model object.
 * @param learning_rate Learning rate to apply to the optimizer.
 * @return Status code enumerated by {@link OH_AI_Status}. The value **OH_AI_Status::OH_AI_STATUS_SUCCESS**
 * indicates that the operation is successful.
 * @since 11
 */
OH_AI_API OH_AI_Status OH_AI_ModelSetLearningRate(OH_AI_ModelHandle model, float learning_rate);

/**
 * @brief Obtains the learning rate for model training. This API is used only for on-device training.
 *
 * @param model Pointer to the model object.
 * @return Current learning rate. If no optimizer is set, the value is <b>0.0</b>.
 * @since 11
 */
OH_AI_API float OH_AI_ModelGetLearningRate(OH_AI_ModelHandle model);

/**
 * @brief Obtains all weight tensors of a model. This API is used only for on-device training.
 *
 * @param model Pointer to the model object.
 * @return All weight tensors of the model, as an {@link OH_AI_TensorHandleArray}.
 * @since 11
 */
OH_AI_API OH_AI_TensorHandleArray OH_AI_ModelGetWeights(OH_AI_ModelHandle model);

/**
 * @brief Updates the weight tensors of a model. This API is used only for on-device training.
 *
 * @param model Pointer to the model object.
 * @param new_weights Weight tensors to update, as an {@link OH_AI_TensorHandleArray}.
 * @return Status code enumerated by {@link OH_AI_Status}. The value **OH_AI_Status::OH_AI_STATUS_SUCCESS**
 * indicates that the operation is successful.
 * @since 11
 */
OH_AI_API OH_AI_Status OH_AI_ModelUpdateWeights(OH_AI_ModelHandle model, const OH_AI_TensorHandleArray new_weights);

/**
 * @brief Obtains the training mode.
 *
 * @param model Pointer to the model object.
 * @return **true** if the model is in training mode; **false** otherwise.
 * @since 11
 */
OH_AI_API bool OH_AI_ModelGetTrainMode(OH_AI_ModelHandle model);

/**
 * @brief Sets the training mode. This API is used only for on-device training.
 *
 * @param model Pointer to the model object.
 * @param train **true** to enable training mode; **false** to disable it.
 * @return Status code enumerated by {@link OH_AI_Status}. The value **OH_AI_Status::OH_AI_STATUS_SUCCESS**
 * indicates that the operation is successful.
 * @since 11
 */
OH_AI_API OH_AI_Status OH_AI_ModelSetTrainMode(OH_AI_ModelHandle model, bool train);

/**
 * @brief Sets the virtual batch for training. This API is used only for on-device training.
 *
 * @param model Pointer to the model object.
 * @param virtual_batch_multiplier Virtual batch multiplier. If the value is less than <b>1</b>,
 * the virtual batch is disabled.
 * @param lr Learning rate. The default value is <b>-1.0f</b>; presumably a negative value keeps
 * the current learning rate — confirm with the implementation.
 * @param momentum Momentum. The default value is <b>-1.0f</b>.
 * @return Status code enumerated by {@link OH_AI_Status}. The value **OH_AI_Status::OH_AI_STATUS_SUCCESS**
 * indicates that the operation is successful.
 * @since 11
 */
OH_AI_API OH_AI_Status OH_AI_ModelSetupVirtualBatch(OH_AI_ModelHandle model, int virtual_batch_multiplier, float lr,
                                                    float momentum);

/**
 * @brief Exports a training model to a file. This API is used only for on-device training.
 *
 * @param model Pointer to the model object.
 * @param model_type Type of the model file. For details, see {@link OH_AI_ModelType}.
 * @param model_file Path of the exported model file.
 * @param quantization_type Quantization type; see {@link OH_AI_QuantizationType}.
 * @param export_inference_only Whether to export inference models only.
 * @param output_tensor_name Output tensor names of the exported model. This parameter is left blank by default,
 * which indicates full export.
 * @param num Number of entries in **output_tensor_name**.
 * @return Status code enumerated by {@link OH_AI_Status}. The value **OH_AI_Status::OH_AI_STATUS_SUCCESS**
 * indicates that the operation is successful.
 * @since 11
 */
OH_AI_API OH_AI_Status OH_AI_ExportModel(OH_AI_ModelHandle model, OH_AI_ModelType model_type, const char *model_file,
                                         OH_AI_QuantizationType quantization_type, bool export_inference_only,
                                         char **output_tensor_name, size_t num);

/**
 * @brief Exports the memory cache of the training model. This API is used only for on-device training.
 *
 * @param model Pointer to the model object.
 * @param model_type Type of the model file. For details, see {@link OH_AI_ModelType}.
 * @param model_data Output parameter; receives a pointer to the buffer that stores the exported model.
 * NOTE(review): ownership of the returned buffer is not stated here — confirm with the implementation
 * whether the caller must free it.
 * @param data_size Output parameter; receives the buffer size, in bytes.
 * @param quantization_type Quantization type; see {@link OH_AI_QuantizationType}.
 * @param export_inference_only Whether to export inference models only.
 * @param output_tensor_name Output tensor names of the exported model. This parameter is left blank by default,
 * which indicates full export.
 * @param num Number of entries in **output_tensor_name**.
 * @return Status code enumerated by {@link OH_AI_Status}. The value **OH_AI_Status::OH_AI_STATUS_SUCCESS**
 * indicates that the operation is successful.
 * @since 11
 */
OH_AI_API OH_AI_Status OH_AI_ExportModelBuffer(OH_AI_ModelHandle model, OH_AI_ModelType model_type, char **model_data,
                                               size_t *data_size, OH_AI_QuantizationType quantization_type,
                                               bool export_inference_only, char **output_tensor_name, size_t num);

/**
 * @brief Exports the weight file of the training model for micro inference. This API is used only for
 * on-device training.
 *
 * @param model Pointer to the model object.
 * @param model_type Type of the model file. For details, see {@link OH_AI_ModelType}.
 * @param weight_file Path of the exported weight file.
 * @param is_inference Whether to export inference models. Currently, this parameter can only be set to <b>true</b>.
 * @param enable_fp16 Whether to save floating-point weights in float16 format.
 * @param changeable_weights_name Names of the weight tensors with a variable shape.
 * @param num Number of entries in **changeable_weights_name**.
 * @return Status code enumerated by {@link OH_AI_Status}. The value **OH_AI_Status::OH_AI_STATUS_SUCCESS**
 * indicates that the operation is successful.
 * @since 11
 */
OH_AI_API OH_AI_Status OH_AI_ExportWeightsCollaborateWithMicro(OH_AI_ModelHandle model, OH_AI_ModelType model_type,
                                                               const char *weight_file, bool is_inference,
                                                               bool enable_fp16, char **changeable_weights_name,
                                                               size_t num);

#ifdef __cplusplus
}
#endif

/** @} */
#endif // MINDSPORE_INCLUDE_C_API_MODEL_C_H
